{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9d435bae-c85f-4368-8cd5-0eff0928458e",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bba3631e-5016-40f4-bd46-cc91e7509f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,math,sys,torch,re,numpy as np\n",
    "from types import SimpleNamespace as ns\n",
    "from collections import namedtuple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbda41fd-dbbf-47d4-807a-67ad565b3bc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Python stand-in for CUDA's `dim3`: `x` is required, while `y` and `z` default to 1.\n",
    "dim3 = namedtuple('dim3', ['x','y','z'], defaults=(1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b57350f-3ff6-4e10-8635-040c3736220d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dim3(x=2, y=3, z=1)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d = dim3(2,3)\n",
    "d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca90d679-3fba-4903-8e14-7ef9efb3bf89",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 3)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d.x,d.y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e41709-f1f3-40c1-aa20-bd19737f3d86",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.set_printoptions(precision=2, linewidth=140)\n",
    "torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db47935d-8477-4116-9538-369c759322bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.insert(0, '..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a76aa950-6f87-452d-b048-82da11be0b24",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import show_img,load_cuda,cuda_begin,cdiv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "994f69fd-3989-46e2-ad84-71b7450f1b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext wurlitzer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed0d99b2-bf34-4e9c-99e5-0ebf4e4b02d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.environ['CUDA_LAUNCH_BLOCKING']='1'\n",
    "torch.manual_seed(42);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc38ccaa-c802-46c2-962a-e1f83eba49d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-size operands, plus a 4-row slice of `m1` and a 4-column slice of `m2`\n",
    "# that are used by the pure-Python kernel simulations below.\n",
    "m1 = torch.rand(5120, 256)\n",
    "m1s = m1[:4]\n",
    "m2 = torch.rand(256,5120)\n",
    "m2s = m2[:,:4]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28662938-ef33-410b-96d4-d7d503c696d6",
   "metadata": {},
   "source": [
    "## Reminder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c9421c9-6cd5-479e-b3f2-7a3a8d2a7b43",
   "metadata": {},
   "source": [
    "### 2d Python kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ef592ed-b605-46f5-b72d-662ab46a55e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def blk_kernel2d(f, blocks, threads, *args):\n",
    "    \"\"\"Emulate a 2D CUDA kernel launch by calling `f` once per (block, thread) pair.\n",
    "\n",
    "    `blocks` and `threads` supply the grid and block extents via their `.x`/`.y` attributes.\n",
    "    `f` is called sequentially as `f(blockIdx, threadIdx, blockDim, *args)`, with `x` varying\n",
    "    fastest within `y` for both blocks and threads.\n",
    "    \"\"\"\n",
    "    for by in range(blocks.y):\n",
    "        for bx in range(blocks.x):\n",
    "            block_idx = dim3(bx, by)\n",
    "            for ty in range(threads.y):\n",
    "                for tx in range(threads.x):\n",
    "                    f(block_idx, dim3(tx, ty), threads, *args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62f8923f-8072-4f52-a644-6576f7dfe352",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):\n",
    "    r = blockIdx.y*blockDim.y + threadIdx.y\n",
    "    c = blockIdx.x*blockDim.x + threadIdx.x\n",
    "    \n",
    "    if (r>=h or c>=w): return\n",
    "    o = 0.\n",
    "    for i in range(k): o += m[r*k+i] * n[i*w+c]\n",
    "    out[r*w+c] = o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a53c2202-1169-4ac6-a477-82b82f3c5201",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_2d(m, n):\n",
    "    \"\"\"Multiply 2D tensors `m` (h x k) and `n` (k x w) with the simulated `matmul_bk` kernel.\n",
    "\n",
    "    Launches enough 16x16 thread blocks to cover the h x w result, which is returned\n",
    "    as a new tensor with `m`'s dtype.\n",
    "    \"\"\"\n",
    "    (h, k), (k2, w) = m.shape, n.shape\n",
    "    assert k==k2, \"Size mismatch!\"\n",
    "    output = torch.zeros(h, w, dtype=m.dtype)\n",
    "    tpb = dim3(16,16)\n",
    "    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))\n",
    "    # The kernel writes into the flattened form of `output`; the result is read back from `output`.\n",
    "    flat_m, flat_n, flat_out = m.flatten(), n.flatten(), output.flatten()\n",
    "    blk_kernel2d(matmul_bk, grid, tpb, flat_m, flat_n, flat_out, h, w, k)\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e54824bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(matmul_2d(m1s, m2s), m1s@m2s).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8023ed5e-6adb-4c00-b234-a80bf774baad",
   "metadata": {},
   "source": [
    "### CUDA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f85d6e4-bc36-4171-b64e-3447359913e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA source for the naive matmul: one thread per output element, reading both operands\n",
    "# directly through the `m`/`n` pointers (no shared-memory tiling). `matmul` is the host-side launcher.\n",
    "cuda_src = cuda_begin + r'''\n",
    "__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {\n",
    "    int r = blockIdx.y*blockDim.y + threadIdx.y;\n",
    "    int c = blockIdx.x*blockDim.x + threadIdx.x;\n",
    "\n",
    "    if (r>=h || c>=w) return;\n",
    "    float o = 0;\n",
    "    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];\n",
    "    out[r*w+c] = o;\n",
    "}\n",
    "\n",
    "torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {\n",
    "    CHECK_INPUT(m); CHECK_INPUT(n);\n",
    "    int h = m.size(0);\n",
    "    int w = n.size(1);\n",
    "    int k = m.size(1);\n",
    "    TORCH_CHECK(k==n.size(0), \"Size mismatch!\");\n",
    "    auto output = torch::zeros({h, w}, m.options());\n",
    "\n",
    "    dim3 tpb(16,16);\n",
    "    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));\n",
    "    matmul_k<<<blocks, tpb>>>(\n",
    "        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a82414f-bd5a-4fb1-9f7f-1f404b8d8a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fname = 'matmul'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d93c0b9a-2f20-460c-8393-3fb241c1ae85",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sig(fname, src):\n",
    "    res = re.findall(rf'^(.+\\s+{fname}\\(.*?\\))\\s*{{?\\s*$', src, re.MULTILINE)\n",
    "    return res[0]+';' if res else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b943b71-641c-4ec1-b336-7d04b4930c1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cpp_src = get_sig(fname, cuda_src)\n",
    "cpp_src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9533561e-6426-4c29-92a6-3f968521d795",
   "metadata": {},
   "outputs": [],
   "source": [
    "module = load_cuda(cuda_src, cpp_src, [fname])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d4e99ab-d2a7-45ef-a38f-f38ca17e22d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "m1c,m2c = m1.contiguous().cuda(),m2.contiguous().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08507713-c7b5-40b4-a7c0-a9ead64f3017",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5120, 5120])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "module.matmul(m1c,m2c).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80fc299e-8c08-40f2-9c17-cc1965b0b430",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True, device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(module.matmul(m1c,m2c), m1c@m2c).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e2c3eb-813e-4e69-9d90-731a8bf321aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6 ms ± 196 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit -n 10\n",
    "module.matmul(m1c,m2c)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "836bd00e-8c17-4208-a159-aa91a00425ae",
   "metadata": {},
   "source": [
    "When I removed the call to the kernel itself, it took around 50 µs (0.05 ms) to run, so that's the overhead of the call on my machine."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32b1cf32-baa7-449b-ac0b-9d787d7bc470",
   "metadata": {},
   "source": [
    "## Shared mem"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed2bd620-dc68-4347-9ac5-3b4335b788dc",
   "metadata": {},
   "source": [
    "### Python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0aa3d263-a704-4c67-8718-0ff5c6f5ca81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `b` and `c` are slices of `a`; the next cell writes through them and then displays `a`.\n",
    "a = torch.zeros(5)\n",
    "b,c = a[:3],a[3:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77a8ed5a-72f6-4509-8a9d-9941bfa33271",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 2., 0., 6., 0.])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b[1] = 2\n",
    "c[0] = 6\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "389a066a-216f-40eb-90d5-4723ccb42b9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):\n",
    "    \"\"\"Call the block-level function `f` once per block of a 2D grid.\n",
    "\n",
    "    Each block gets a fresh zeroed 1D tensor of `sh_sz` elements as its shared buffer, and\n",
    "    `f` is invoked as `f(blockIdx, blockDim, shared, *args, **kwargs)` with `threads`\n",
    "    passed through as `blockDim`.\n",
    "    \"\"\"\n",
    "    for by in range(blocks.y):\n",
    "        for bx in range(blocks.x):\n",
    "            f(dim3(bx,by), threads, torch.zeros(sh_sz), *args, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac969e7d-266d-48f6-9aa7-856ad5de92a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):\n",
    "    \"\"\"Simulate one thread block of the tiled matmul, accumulating into the flat `out`.\n",
    "\n",
    "    `m` (h x k), `n` (k x w) and `out` (h x w) are flat row-major buffers. `shared` is split\n",
    "    in two at `tw*tw`: the first part caches a tile of `m`, the rest a tile of `n`.\n",
    "    The tile indexing assumes `blockDim.x == blockDim.y == tw`. Partial products are added\n",
    "    to `out` phase by phase, so `out` is expected to start at zero.\n",
    "    \"\"\"\n",
    "    shar_sz = tw*tw\n",
    "    ms,ns = shared[:shar_sz],shared[shar_sz:]\n",
    "\n",
    "    for ph in range(cdiv(k,tw)):\n",
    "        idx = ph*tw\n",
    "        # fill shared: every simulated thread loads one element per tile, zero-padding past the edges\n",
    "        for tr in range(blockDim.y):\n",
    "            for tc in range(blockDim.x):\n",
    "                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc\n",
    "                ms[tr*tw+tc] = m[ tc+idx + r*k] if r<h and idx+tc<k else 0.\n",
    "                ns[tr*tw+tc] = n[(tr+idx)*w +c] if c<w and idx+tr<k else 0.\n",
    "\n",
    "        # do dotprods from shared\n",
    "        for tr in range(blockDim.y):\n",
    "            for tc in range(blockDim.x):\n",
    "                r,c = blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc\n",
    "                # Guard on the 2D coordinates: a flat `r*w+c < len(out)` test would let a\n",
    "                # padding column (c>=w) alias an element of the following row.\n",
    "                if r>=h or c>=w: continue\n",
    "                for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fb0c224-5781-4e92-9b59-36a6fa39c909",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_2d(m, n, tw=16):\n",
    "    \"\"\"Tiled matmul of `m` (h x k) by `n` (k x w) using the simulated shared-memory kernel.\n",
    "\n",
    "    Uses `tw` x `tw` thread blocks and a shared buffer of 2*tw*tw elements per block.\n",
    "    Returns a new h x w tensor with `m`'s dtype.\n",
    "    \"\"\"\n",
    "    (h, k), (k2, w) = m.shape, n.shape\n",
    "    assert k==k2, \"Size mismatch!\"\n",
    "    output = torch.zeros(h, w, dtype=m.dtype)\n",
    "    tpb = dim3(tw,tw)\n",
    "    grid = dim3(cdiv(w,tpb.x), cdiv(h,tpb.y))\n",
    "    sh_sz = tw*tw*2  # one tile of `m` plus one tile of `n`\n",
    "    blk_kernel2d_shar(matmul_tiled_bk, grid, tpb, sh_sz,\n",
    "                      m.flatten(), n.flatten(), output.flatten(),\n",
    "                      h, w, k, tw=tw)\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7021417d-4932-46ac-838e-bbeb0fc27504",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4, 256]), torch.Size([256, 5120]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m1s.shape, m2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "accf5bf5-0874-4a30-aeab-f01034be2a0e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d1a31cf-b54f-4b8a-982f-51ff1a68bb6a",
   "metadata": {},
   "source": [
    "### Python run_threads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a44f65b-3ad1-4654-92b9-5b61ede1a6c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_threads(f, blockDim, *args, **kwargs):\n",
    "    for i0 in range(blockDim.y):\n",
    "        for i1 in range(blockDim.x): f(i0, i1, *args, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "096fbfc1-5f01-4203-94df-98a62ddc2ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):\n",
    "    \"\"\"Simulate one thread block of the tiled matmul using `run_threads` for the per-thread work.\n",
    "\n",
    "    `m` (h x k), `n` (k x w) and `out` (h x w) are flat row-major buffers. `shared` is split\n",
    "    in two at `tw*tw`: the first part caches a tile of `m`, the rest a tile of `n`.\n",
    "    The tile indexing assumes `blockDim.x == blockDim.y == tw`. Partial products are added\n",
    "    to `out` phase by phase, so `out` is expected to start at zero.\n",
    "    \"\"\"\n",
    "    shar_sz = tw*tw\n",
    "    ms,ns = shared[:shar_sz],shared[shar_sz:]\n",
    "\n",
    "    def get_rc(tr, tc): return blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc\n",
    "\n",
    "    def fill_shared_tk(tr, tc, ph):\n",
    "        # Each thread loads one element per tile, zero-padding past the matrix edges.\n",
    "        r,c = get_rc(tr, tc)\n",
    "        ms[tr*tw+tc] = m[ tc + ph*tw + r*k] if r<h and (ph*tw+tc)<k else 0.\n",
    "        ns[tr*tw+tc] = n[(tr + ph*tw)*w +c] if c<w and (ph*tw+tr)<k else 0.\n",
    "\n",
    "    def dotprod_tk(tr, tc):\n",
    "        r,c = get_rc(tr, tc)\n",
    "        # Guard on the 2D coordinates: a flat `r*w+c < len(out)` test would let a\n",
    "        # padding column (c>=w) alias an element of the following row.\n",
    "        if r>=h or c>=w: return\n",
    "        for i in range(tw): out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]\n",
    "\n",
    "    # `cdiv` matches the phase count used by the other variants of this kernel.\n",
    "    for ph in range(cdiv(k,tw)):\n",
    "        run_threads(fill_shared_tk, blockDim, ph)\n",
    "        run_threads(dotprod_tk, blockDim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0ae141e-be1a-42cd-85aa-64be0add9e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_2d(m, n, tw=16):\n",
    "    \"\"\"Tiled matmul of `m` (h x k) by `n` (k x w) on the simulated block kernel.\n",
    "\n",
    "    Dispatches to whichever `blk_kernel2d_shar` and `matmul_tiled_bk` are currently defined,\n",
    "    with `tw` x `tw` blocks and 2*tw*tw shared elements. Returns a new h x w tensor.\n",
    "    \"\"\"\n",
    "    h, k = m.shape\n",
    "    k2, w = n.shape\n",
    "    assert k==k2, \"Size mismatch!\"\n",
    "    result = torch.zeros(h, w, dtype=m.dtype)\n",
    "    block_dim = dim3(tw,tw)\n",
    "    grid_dim = dim3(cdiv(w,block_dim.x), cdiv(h,block_dim.y))\n",
    "    blk_kernel2d_shar(matmul_tiled_bk, grid_dim, block_dim, tw*tw*2,\n",
    "                      m.flatten(), n.flatten(), result.flatten(), h, w, k, tw=tw)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb489de0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4, 256]), torch.Size([256, 4]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m1s.shape, m2s.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e65cb4e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31b204c2-e349-4afb-9787-a9a045b78977",
   "metadata": {},
   "source": [
    "### Python threads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa6aec96-1e67-4d6d-81af-1f6c3cb6de19",
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "from threading import Barrier, Thread\n",
    "from concurrent.futures import ThreadPoolExecutor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28306d45-3176-4cef-95f1-5894cf494c58",
   "metadata": {},
   "outputs": [],
   "source": [
    "def g(x, sb):\n",
    "    print(x)\n",
    "    sb.wait()\n",
    "    print(-x)\n",
    "    sb.wait()\n",
    "    print(x*10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff4e712e-5999-4e4b-a057-db65996b657d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "-3\n",
      "-1\n",
      "-2\n",
      "30\n",
      "20\n",
      "10\n"
     ]
    }
   ],
   "source": [
    "# Run `g` in three workers that share one barrier sized for all of them, so each\n",
    "# `sb.wait()` holds a worker until the other two have reached the same point.\n",
    "num = 3\n",
    "sb = Barrier(num)\n",
    "with ThreadPoolExecutor(num) as ex: list(ex.map(lambda i: g(i,sb), range(1,num+1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bac722db-cea2-494f-9ad0-a46c010e61e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):\n",
    "    \"\"\"Simulate a 2D kernel launch with one Python thread per CUDA thread.\n",
    "\n",
    "    Blocks run one after another. Each block gets a fresh zeroed buffer of `sh_sz` elements\n",
    "    and a `Barrier` sized for all `tpb.x*tpb.y` threads; every thread runs\n",
    "    `f(blockIdx, threadIdx, tpb, shared, barrier, *args, **kwargs)` and the block is joined\n",
    "    before the next one starts.\n",
    "    \"\"\"\n",
    "    n_threads = tpb.y*tpb.x\n",
    "    for i0 in range(blocks.y):\n",
    "        for i1 in range(blocks.x):\n",
    "            shar = torch.zeros(sh_sz)\n",
    "            syncb = Barrier(n_threads)\n",
    "            workers = []\n",
    "            for o in range(tpb.y):\n",
    "                for p in range(tpb.x):\n",
    "                    th_args = (dim3(i1,i0), dim3(p,o), tpb, shar, syncb, *args)\n",
    "                    workers.append(Thread(target=f, args=th_args, kwargs=kwargs))\n",
    "            for wk in workers: wk.start()\n",
    "            for wk in workers: wk.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f86730-7484-404d-b542-861a5516aa72",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):\n",
    "    \"\"\"Body of one simulated thread of the tiled matmul.\n",
    "\n",
    "    `m` (h x k), `n` (k x w) and `out` (h x w) are flat row-major buffers; `shared` is split at\n",
    "    `tw*tw` into a tile of `m` and a tile of `n`. `syncb.wait()` is called twice per phase:\n",
    "    once after this thread has stored its tile elements and once after it has consumed the tiles.\n",
    "    \"\"\"\n",
    "    tc,tr = threadIdx.x,threadIdx.y\n",
    "    r = blockIdx.y*blockDim.y + tr\n",
    "    c = blockIdx.x*blockDim.x + tc\n",
    "\n",
    "    tile_sz = tw*tw\n",
    "    ms,ns = shared[:tile_sz],shared[tile_sz:]\n",
    "    slot = tr*tw+tc  # this thread's position inside each tile\n",
    "\n",
    "    acc = 0.\n",
    "    for ph in range(cdiv(k,tw)):\n",
    "        base = ph*tw\n",
    "        # Out-of-range positions are stored as zero so they add nothing to the dot product.\n",
    "        if r<h and (base+tc)<k: ms[slot] = m[ tc + base + r*k]\n",
    "        else: ms[slot] = 0.\n",
    "        if c<w and (base+tr)<k: ns[slot] = n[(tr + base)*w +c]\n",
    "        else: ns[slot] = 0.\n",
    "        syncb.wait()\n",
    "        for i in range(tw): acc += ms[tr*tw+i] * ns[tw*i+tc]\n",
    "        syncb.wait()\n",
    "\n",
    "    if (r<h and c<w): out[r*w + c] = acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f758eb6b-fc56-43e0-943f-3afd8ce86dbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_2d(m, n, tw=16):\n",
    "    \"\"\"Multiply `m` (h x k) by `n` (k x w) through the simulated tiled kernel.\n",
    "\n",
    "    Calls the currently defined `blk_kernel2d_shar` with `matmul_tiled_bk`, a `tw` x `tw`\n",
    "    block and a shared buffer of 2*tw*tw elements. Returns a new h x w tensor with `m`'s dtype.\n",
    "    \"\"\"\n",
    "    h, k = m.shape\n",
    "    k2, w = n.shape\n",
    "    assert k==k2, \"Size mismatch!\"\n",
    "    output = torch.zeros(h, w, dtype=m.dtype)\n",
    "    tpb = dim3(tw,tw)\n",
    "    n_blocks_x, n_blocks_y = cdiv(w,tpb.x), cdiv(h,tpb.y)\n",
    "    flat_operands = (m.flatten(), n.flatten(), output.flatten())\n",
    "    blk_kernel2d_shar(matmul_tiled_bk, dim3(n_blocks_x, n_blocks_y), tpb, tw*tw*2,\n",
    "                      *flat_operands, h, w, k, tw=tw)\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270af122-7282-4934-833d-6fe195efc7b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(matmul_2d(m1s, m2s, tw=8), m1s@m2s).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5bfbea5-8cd1-41c0-aca8-298a48599b52",
   "metadata": {},
   "source": [
    "### CUDA dynamic shared"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f42f931-580e-4ced-8735-354abe6499ef",
   "metadata": {},
   "source": [
    "Code auto-generated by ChatGPT 4, using the following prompt:\n",
    "\n",
    "> Convert the following python code to CUDA C, keeping formatting and variable names the same where possible. You can remove `blockIdx, threadIdx, blockDim, shared` from the argument list, since they're already provided by CUDA. Change `syncb.wait()` to `__syncthreads`. Use `extern __shared__ float shared[]` to create the `shared` array. Use the C ternary operator to replace the Python equivalent where appropriate. If the Python code uses any non-standard functions, you can assume the same functions are also available to the translated C code with the same name and signature.\n",
    "\n",
    "The generated code worked first time, although we did some minor cleanups afterwards (e.g. renaming `shared` to `ms`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97fa19f4-8ca6-4ce0-a89c-d82a79a71d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tiled matmul kernel with runtime tile width `tw`: `ms` is declared `extern __shared__`\n",
    "# and `ns` starts `tw*tw` floats into the same buffer.\n",
    "cuda_src = cuda_begin + r'''\n",
    "__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {\n",
    "    int tc=threadIdx.x, tr=threadIdx.y;\n",
    "    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;\n",
    "\n",
    "    extern __shared__ float ms[];\n",
    "    float *ns = &ms[tw*tw];\n",
    "\n",
    "    float p = 0.0f;\n",
    "    for (int ph = 0; ph < cdiv(k,tw); ++ph) {\n",
    "        int idx = ph*tw;\n",
    "        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;\n",
    "        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;\n",
    "        __syncthreads();\n",
    "        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];\n",
    "        __syncthreads();\n",
    "    }\n",
    "    if (r<h && c<w) out[r*w + c] = p;\n",
    "}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ef35af0-02a5-44f1-ae3e-c05258a62bd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Host launcher appended to the kernel source: it passes `size` (bytes for two TW x TW float\n",
    "# tiles) as the third launch parameter and the tile width `TW` as the last kernel argument.\n",
    "cuda_src += r'''\n",
    "torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {\n",
    "    CHECK_INPUT(m); CHECK_INPUT(n);\n",
    "    int h=m.size(0), w=n.size(1), k=m.size(1);\n",
    "    TORCH_CHECK(k==n.size(0), \"Size mismatch!\");\n",
    "    auto output = torch::zeros({h, w}, m.options());\n",
    "\n",
    "    /*\n",
    "    // Commented out section demonstrating basic idea of dynamic size calculation\n",
    "    cudaDeviceProp devProp;\n",
    "    CUDA_ERR(cudaGetDeviceProperties(&devProp, 0));\n",
    "    int maxThreads = devProp.maxThreadsPerBlock;\n",
    "    size_t requiredSize = static_cast<size_t>(maxThreads) * 2 * sizeof(float);\n",
    "    size_t size = min(devProp.sharedMemPerBlock, requiredSize);\n",
    "    int TW = std::sqrt(maxThreads);\n",
    "    */\n",
    "\n",
    "    // We just set size fixed for now\n",
    "    int TW = 16;\n",
    "    size_t size = TW*TW * 2 * sizeof(float);\n",
    "    dim3 tpb(TW,TW);\n",
    "    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));\n",
    "    matmul_k<<<blocks,tpb,size>>>(\n",
    "        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c81f01-589e-41f9-899a-a2050b1b4fdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "fname = 'matmul_dyn'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1a24b0-0cb9-46dd-97fa-4df0bca629d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "cpp_src = get_sig(fname, cuda_src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7352ea50-96ca-470e-9a16-360bf09ddd75",
   "metadata": {},
   "outputs": [],
   "source": [
    "module = load_cuda(cuda_src, cpp_src, [fname], opt=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1895e28c-8fda-403c-bd43-95a481960fbc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True, device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(module.matmul_dyn(m1c,m2c), m1c@m2c).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc98195e-0dce-4487-9d85-aec2fdadabae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.64 ms ± 319 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit -n 10\n",
    "module.matmul_dyn(m1c,m2c)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20fda413-cac6-4f4f-b362-61f6a8db4056",
   "metadata": {},
   "source": [
    "### CUDA static shared"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d03d0811-0421-46d2-b67f-dabd7557666d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same tiled matmul, but the tile width is a compile-time constant (`constexpr int tw`),\n",
    "# so both tiles are declared as fixed-size 2D `__shared__` arrays and the launch passes no size.\n",
    "cuda_src = cuda_begin + r'''\n",
    "constexpr int tw = 16;\n",
    "\n",
    "__global__ void matmul_ks(float *m, float *n, float *out, int h, int w, int k) {\n",
    "    __shared__ float ms[tw][tw], ns[tw][tw];\n",
    "    int tc=threadIdx.x, tr=threadIdx.y;\n",
    "    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;\n",
    "\n",
    "    float p=0.0f;\n",
    "    for (int ph=0; ph < cdiv(k,tw); ++ph) {\n",
    "        int idx = ph*tw;\n",
    "        ms[tr][tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;\n",
    "        ns[tr][tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;\n",
    "        __syncthreads();\n",
    "        for (int i=0; i<tw; ++i) p += ms[tr][i] * ns[i][tc];\n",
    "        __syncthreads();\n",
    "    }\n",
    "    if (r<h && c<w) out[r*w + c] = p;\n",
    "}\n",
    "\n",
    "torch::Tensor matmul_static(torch::Tensor m, torch::Tensor n) {\n",
    "    CHECK_INPUT(m); CHECK_INPUT(n);\n",
    "    int h=m.size(0), w=n.size(1), k=m.size(1);\n",
    "    TORCH_CHECK(k==n.size(0), \"Size mismatch!\");\n",
    "    auto output = torch::zeros({h, w}, m.options());\n",
    "    dim3 tpb(tw,tw);\n",
    "    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));\n",
    "    matmul_ks<<<blocks,tpb>>>(m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36bea62d-90ee-41c4-8df2-aca17911ae50",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True, device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fname = 'matmul_static'\n",
    "cpp_src = get_sig(fname, cuda_src)\n",
    "module = load_cuda(cuda_src, cpp_src, [fname])\n",
    "torch.isclose(module.matmul_static(m1c,m2c), m1c@m2c).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ab2501-a73a-48d3-9ee5-0c1e508e1fe4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.34 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit -n 10\n",
    "module.matmul_static(m1c,m2c)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe5ab0d1-f574-42e1-9d50-e9ffdcd50861",
   "metadata": {},
   "source": [
    "## Numba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99d0a300-2763-4b9d-a496-eb3c877c006b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from numba import cuda\n",
    "from numba.cuda import as_cuda_array as ca"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b409312b-1dca-4e63-995d-9b519b61b58f",
   "metadata": {},
   "outputs": [],
   "source": [
    "@cuda.jit\n",
    "def matmul_k_numba(m, n, out, tw):\n",
    "    \"\"\"Tiled matmul kernel: each thread accumulates one element of `out` from `m` and `n`.\n",
    "\n",
    "    `m`, `n` and `out` are indexed as 2D arrays; `tw` is the tile width. The tile indexing\n",
    "    (`tr*tw+tc`) assumes the kernel is launched with `tw` x `tw` thread blocks.\n",
    "    \"\"\"\n",
    "    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx\n",
    "    tc,tr = tid.x,tid.y\n",
    "    r,c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc\n",
    "    h,k  = m.shape\n",
    "    k2,w = n.shape\n",
    "\n",
    "    # Shape 0 requests dynamic shared memory; its byte size is supplied in the launch configuration.\n",
    "    shar = cuda.shared.array(0, dtype=np.float32)\n",
    "    # The first tw*tw floats hold the tile of `m`, the next tw*tw the tile of `n`.\n",
    "    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]\n",
    "\n",
    "    p = np.float32(0.0)\n",
    "    for ph in range(math.ceil(k/tw)):\n",
    "        idx = ph*tw\n",
    "        # Each thread stores one element per tile; positions outside the matrices become 0.\n",
    "        ms[tr*tw+tc] = m[r, tc+idx] if r<h and idx+tc<k else 0.\n",
    "        ns[tr*tw+tc] = n[tr+idx, c] if c<w and idx+tr<k else 0.\n",
    "        cuda.syncthreads()\n",
    "        for i in range(tw): p += ms[tr*tw+i] * ns[i*tw+tc]\n",
    "        cuda.syncthreads()\n",
    "    # Only threads that map to a real output element write their result.\n",
    "    if r < h and c < w: out[r, c] = p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "223cdf11-1d26-44db-8c9b-aaac3607f6b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul_2d_numba(m, n, tw=16):\n",
    "    \"\"\"Compute the product of `m` (h x k) and `n` (k x w) with the Numba kernel `matmul_k_numba`.\n",
    "\n",
    "    The result is allocated with `m`'s dtype on `m`'s device. `tw` is both the tile width\n",
    "    passed to the kernel and the side of the square thread block.\n",
    "    \"\"\"\n",
    "    (h, k), (k2, w) = m.shape, n.shape\n",
    "    assert k==k2, \"Size mismatch!\"\n",
    "    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)\n",
    "    float32_bytes = 4\n",
    "    dyn_shared_mem_size = 2 * tw * tw * float32_bytes  # two tw x tw float32 tiles\n",
    "    block_shape = (tw, tw)\n",
    "    grid_shape = (cdiv(w, tw), cdiv(h, tw))\n",
    "    # launch configuration: grid, block, stream, dynamic shared memory bytes\n",
    "    kernel = matmul_k_numba[grid_shape, block_shape, 0, dyn_shared_mem_size]\n",
    "    kernel(ca(m), ca(n), ca(out), tw)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a09a6c3-e161-4ad1-ad50-c8020847b220",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True, device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(matmul_2d_numba(m1c,m2c), m1c@m2c).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d50ac1f7-0a23-4aed-9f45-0cd6db294825",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16.2 ms ± 68.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit -n 10\n",
    "matmul_2d_numba(m1c,m2c)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfdd47b1-6abc-47ad-b15c-6655abcbebed",
   "metadata": {},
   "source": [
    "## Extra: Optimised Dynamic CUDA with Template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eeb1e8e-69a6-43b2-9ded-c605222a4bf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dynamic-shared-memory kernel again, with the tile width moved from a runtime argument\n",
    "# to the template parameter `tw`, making it a compile-time constant inside the kernel.\n",
    "cuda_src = cuda_begin + r'''\n",
    "template<int tw>\n",
    "__global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k) {\n",
    "    int tc=threadIdx.x, tr=threadIdx.y;\n",
    "    int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;\n",
    "    extern __shared__ float ms[];\n",
    "    float *ns = &ms[tw*tw];\n",
    "\n",
    "    float p = 0.0f;\n",
    "    for (int ph = 0; ph < cdiv(k,tw); ++ph) {\n",
    "        int idx = ph*tw;\n",
    "        ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;\n",
    "        ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;\n",
    "        __syncthreads();\n",
    "        for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];\n",
    "        __syncthreads();\n",
    "    }\n",
    "    if (r<h && c<w) out[r*w + c] = p;\n",
    "}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a5df8d7-d58f-4481-bc6a-f39e7b866c36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Host launcher for the templated kernel: maps the runtime tile width onto one of the\n",
    "# pre-instantiated `matmul_k<tw>` specialisations and rejects any other width.\n",
    "cuda_src += r'''\n",
    "torch::Tensor matmul_dyn1(torch::Tensor m, torch::Tensor n) {\n",
    "    CHECK_INPUT(m); CHECK_INPUT(n);\n",
    "    int h=m.size(0), w=n.size(1), k=m.size(1);\n",
    "    TORCH_CHECK(k==n.size(0), \"Size mismatch!\");\n",
    "    auto output = torch::zeros({h, w}, m.options());\n",
    "    int TW = 16; // TODO: Calculate this dynamically\n",
    "    // Two TW x TW float tiles (ms and ns) share the dynamic shared-memory buffer.\n",
    "    size_t size = TW*TW*2 * sizeof(float);\n",
    "    dim3 tpb(TW,TW);\n",
    "    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));\n",
    "\n",
    "    // The kernel takes the tile width as a template argument, so the runtime value\n",
    "    // is dispatched to a matching instantiation through this generic launcher.\n",
    "    auto f = [&](auto kf) { kf<<<blocks, tpb, size>>>(\n",
    "        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);\n",
    "    };\n",
    "    switch(TW) {\n",
    "        case 8: f(matmul_k<8>); break;\n",
    "        case 16: f(matmul_k<16>); break;\n",
    "        case 32: f(matmul_k<32>); break;\n",
    "        // Fail loudly rather than return an all-zero result without launching a kernel.\n",
    "        default: TORCH_CHECK(false, \"Unsupported tile width: \", TW);\n",
    "    }\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f457f70c-800b-4886-b74d-616704ea31b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 93.2 ms, sys: 37.3 ms, total: 130 ms\n",
      "Wall time: 43.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "fname = 'matmul_dyn1'\n",
    "cpp_src = get_sig(fname, cuda_src)\n",
    "module = load_cuda(cuda_src, cpp_src, [fname], opt=True)\n",
    "func = getattr(module, fname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dd1bfc9-819e-4909-aeef-ff340cd7ac25",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True, device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(func(m1c,m2c), m1c@m2c).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89a1e8e7-238d-46c2-b518-d1e5e8fc6561",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.35 ms ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit -n 10\n",
    "func(m1c,m2c)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35fe0c4a-1990-4b07-8c37-c7d44277d1e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
